{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# <div style='color:#fe618e;font-weight:800;'>最最简单的超级马里奥训练过程</div>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nes_py.wrappers import JoypadSpace\n",
    "import gym_super_mario_bros\n",
    "from gym_super_mario_bros.actions import SIMPLE_MOVEMENT\n",
    "import time\n",
    "from matplotlib import pyplot as plt\n",
    "from stable_baselines3 import PPO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the 'SuperMarioBros-v0' environment and wrap it in JoypadSpace\n",
    "# together with the SIMPLE_MOVEMENT action list, in a single expression.\n",
    "env = JoypadSpace(\n",
    "    gym_super_mario_bros.make('SuperMarioBros-v0'),\n",
    "    SIMPLE_MOVEMENT,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using cuda device\n",
      "Wrapping the env with a `Monitor` wrapper\n",
      "Wrapping the env in a DummyVecEnv.\n",
      "Wrapping the env in a VecTransposeImage.\n",
      "Logging to ./tensorboard_log/PPO_1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\software\\e_anaconda\\envs\\pytorch\\lib\\site-packages\\gym_super_mario_bros\\smb_env.py:148: RuntimeWarning: overflow encountered in ubyte_scalars\n",
      "  return (self.ram[0x86] - self.ram[0x071c]) % 256\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-----------------------------\n",
      "| time/              |      |\n",
      "|    fps             | 116  |\n",
      "|    iterations      | 1    |\n",
      "|    time_elapsed    | 17   |\n",
      "|    total_timesteps | 2048 |\n",
      "-----------------------------\n",
      "-----------------------------------------\n",
      "| time/                   |             |\n",
      "|    fps                  | 81          |\n",
      "|    iterations           | 2           |\n",
      "|    time_elapsed         | 50          |\n",
      "|    total_timesteps      | 4096        |\n",
      "| train/                  |             |\n",
      "|    approx_kl            | 0.025405666 |\n",
      "|    clip_fraction        | 0.274       |\n",
      "|    clip_range           | 0.2         |\n",
      "|    entropy_loss         | -1.92       |\n",
      "|    explained_variance   | 0.00504     |\n",
      "|    learning_rate        | 0.0003      |\n",
      "|    loss                 | 0.621       |\n",
      "|    n_updates            | 10          |\n",
      "|    policy_gradient_loss | 0.0109      |\n",
      "|    value_loss           | 17.4        |\n",
      "-----------------------------------------\n",
      "-----------------------------------------\n",
      "| time/                   |             |\n",
      "|    fps                  | 73          |\n",
      "|    iterations           | 3           |\n",
      "|    time_elapsed         | 83          |\n",
      "|    total_timesteps      | 6144        |\n",
      "| train/                  |             |\n",
      "|    approx_kl            | 0.010906073 |\n",
      "|    clip_fraction        | 0.109       |\n",
      "|    clip_range           | 0.2         |\n",
      "|    entropy_loss         | -1.92       |\n",
      "|    explained_variance   | 0.0211      |\n",
      "|    learning_rate        | 0.0003      |\n",
      "|    loss                 | 0.101       |\n",
      "|    n_updates            | 20          |\n",
      "|    policy_gradient_loss | -0.00392    |\n",
      "|    value_loss           | 0.187       |\n",
      "-----------------------------------------\n",
      "-----------------------------------------\n",
      "| time/                   |             |\n",
      "|    fps                  | 69          |\n",
      "|    iterations           | 4           |\n",
      "|    time_elapsed         | 117         |\n",
      "|    total_timesteps      | 8192        |\n",
      "| train/                  |             |\n",
      "|    approx_kl            | 0.009882288 |\n",
      "|    clip_fraction        | 0.0681      |\n",
      "|    clip_range           | 0.2         |\n",
      "|    entropy_loss         | -1.9        |\n",
      "|    explained_variance   | 0.101       |\n",
      "|    learning_rate        | 0.0003      |\n",
      "|    loss                 | 0.0738      |\n",
      "|    n_updates            | 30          |\n",
      "|    policy_gradient_loss | -0.00502    |\n",
      "|    value_loss           | 0.13        |\n",
      "-----------------------------------------\n",
      "-----------------------------------------\n",
      "| rollout/                |             |\n",
      "|    ep_len_mean          | 1.01e+04    |\n",
      "|    ep_rew_mean          | 891         |\n",
      "| time/                   |             |\n",
      "|    fps                  | 65          |\n",
      "|    iterations           | 5           |\n",
      "|    time_elapsed         | 156         |\n",
      "|    total_timesteps      | 10240       |\n",
      "| train/                  |             |\n",
      "|    approx_kl            | 0.008186281 |\n",
      "|    clip_fraction        | 0.105       |\n",
      "|    clip_range           | 0.2         |\n",
      "|    entropy_loss         | -1.87       |\n",
      "|    explained_variance   | 0.0161      |\n",
      "|    learning_rate        | 0.0003      |\n",
      "|    loss                 | 0.28        |\n",
      "|    n_updates            | 40          |\n",
      "|    policy_gradient_loss | -0.00649    |\n",
      "|    value_loss           | 0.811       |\n",
      "-----------------------------------------\n",
      "-----------------------------------------\n",
      "| rollout/                |             |\n",
      "|    ep_len_mean          | 1.01e+04    |\n",
      "|    ep_rew_mean          | 891         |\n",
      "| time/                   |             |\n",
      "|    fps                  | 64          |\n",
      "|    iterations           | 6           |\n",
      "|    time_elapsed         | 190         |\n",
      "|    total_timesteps      | 12288       |\n",
      "| train/                  |             |\n",
      "|    approx_kl            | 0.024062362 |\n",
      "|    clip_fraction        | 0.246       |\n",
      "|    clip_range           | 0.2         |\n",
      "|    entropy_loss         | -1.9        |\n",
      "|    explained_variance   | 0.269       |\n",
      "|    learning_rate        | 0.0003      |\n",
      "|    loss                 | 0.54        |\n",
      "|    n_updates            | 50          |\n",
      "|    policy_gradient_loss | 0.0362      |\n",
      "|    value_loss           | 10.8        |\n",
      "-----------------------------------------\n",
      "-----------------------------------------\n",
      "| rollout/                |             |\n",
      "|    ep_len_mean          | 1.01e+04    |\n",
      "|    ep_rew_mean          | 891         |\n",
      "| time/                   |             |\n",
      "|    fps                  | 63          |\n",
      "|    iterations           | 7           |\n",
      "|    time_elapsed         | 225         |\n",
      "|    total_timesteps      | 14336       |\n",
      "| train/                  |             |\n",
      "|    approx_kl            | 0.024466533 |\n",
      "|    clip_fraction        | 0.211       |\n",
      "|    clip_range           | 0.2         |\n",
      "|    entropy_loss         | -1.89       |\n",
      "|    explained_variance   | 0.839       |\n",
      "|    learning_rate        | 0.0003      |\n",
      "|    loss                 | 0.435       |\n",
      "|    n_updates            | 60          |\n",
      "|    policy_gradient_loss | 0.023       |\n",
      "|    value_loss           | 3.06        |\n",
      "-----------------------------------------\n",
      "----------------------------------------\n",
      "| rollout/                |            |\n",
      "|    ep_len_mean          | 1.01e+04   |\n",
      "|    ep_rew_mean          | 891        |\n",
      "| time/                   |            |\n",
      "|    fps                  | 63         |\n",
      "|    iterations           | 8          |\n",
      "|    time_elapsed         | 259        |\n",
      "|    total_timesteps      | 16384      |\n",
      "| train/                  |            |\n",
      "|    approx_kl            | 0.01970315 |\n",
      "|    clip_fraction        | 0.242      |\n",
      "|    clip_range           | 0.2        |\n",
      "|    entropy_loss         | -1.9       |\n",
      "|    explained_variance   | 0.486      |\n",
      "|    learning_rate        | 0.0003     |\n",
      "|    loss                 | 0.526      |\n",
      "|    n_updates            | 70         |\n",
      "|    policy_gradient_loss | 0.00486    |\n",
      "|    value_loss           | 1.57       |\n",
      "----------------------------------------\n",
      "-----------------------------------------\n",
      "| rollout/                |             |\n",
      "|    ep_len_mean          | 1.01e+04    |\n",
      "|    ep_rew_mean          | 891         |\n",
      "| time/                   |             |\n",
      "|    fps                  | 62          |\n",
      "|    iterations           | 9           |\n",
      "|    time_elapsed         | 293         |\n",
      "|    total_timesteps      | 18432       |\n",
      "| train/                  |             |\n",
      "|    approx_kl            | 0.012460884 |\n",
      "|    clip_fraction        | 0.217       |\n",
      "|    clip_range           | 0.2         |\n",
      "|    entropy_loss         | -1.87       |\n",
      "|    explained_variance   | 0.74        |\n",
      "|    learning_rate        | 0.0003      |\n",
      "|    loss                 | 0.139       |\n",
      "|    n_updates            | 80          |\n",
      "|    policy_gradient_loss | -0.000311   |\n",
      "|    value_loss           | 0.734       |\n",
      "-----------------------------------------\n",
      "----------------------------------------\n",
      "| rollout/                |            |\n",
      "|    ep_len_mean          | 1.01e+04   |\n",
      "|    ep_rew_mean          | 891        |\n",
      "| time/                   |            |\n",
      "|    fps                  | 62         |\n",
      "|    iterations           | 10         |\n",
      "|    time_elapsed         | 327        |\n",
      "|    total_timesteps      | 20480      |\n",
      "| train/                  |            |\n",
      "|    approx_kl            | 0.02535792 |\n",
      "|    clip_fraction        | 0.298      |\n",
      "|    clip_range           | 0.2        |\n",
      "|    entropy_loss         | -1.88      |\n",
      "|    explained_variance   | 0.405      |\n",
      "|    learning_rate        | 0.0003     |\n",
      "|    loss                 | 1.17       |\n",
      "|    n_updates            | 90         |\n",
      "|    policy_gradient_loss | 0.0205     |\n",
      "|    value_loss           | 6.6        |\n",
      "----------------------------------------\n",
      "-----------------------------------------\n",
      "| rollout/                |             |\n",
      "|    ep_len_mean          | 1.01e+04    |\n",
      "|    ep_rew_mean          | 891         |\n",
      "| time/                   |             |\n",
      "|    fps                  | 62          |\n",
      "|    iterations           | 11          |\n",
      "|    time_elapsed         | 361         |\n",
      "|    total_timesteps      | 22528       |\n",
      "| train/                  |             |\n",
      "|    approx_kl            | 0.019694094 |\n",
      "|    clip_fraction        | 0.243       |\n",
      "|    clip_range           | 0.2         |\n",
      "|    entropy_loss         | -1.91       |\n",
      "|    explained_variance   | 0.952       |\n",
      "|    learning_rate        | 0.0003      |\n",
      "|    loss                 | 0.39        |\n",
      "|    n_updates            | 100         |\n",
      "|    policy_gradient_loss | -0.00434    |\n",
      "|    value_loss           | 1.31        |\n",
      "-----------------------------------------\n",
      "-----------------------------------------\n",
      "| rollout/                |             |\n",
      "|    ep_len_mean          | 1.19e+04    |\n",
      "|    ep_rew_mean          | 884         |\n",
      "| time/                   |             |\n",
      "|    fps                  | 61          |\n",
      "|    iterations           | 12          |\n",
      "|    time_elapsed         | 398         |\n",
      "|    total_timesteps      | 24576       |\n",
      "| train/                  |             |\n",
      "|    approx_kl            | 0.013096321 |\n",
      "|    clip_fraction        | 0.227       |\n",
      "|    clip_range           | 0.2         |\n",
      "|    entropy_loss         | -1.91       |\n",
      "|    explained_variance   | 0.0132      |\n",
      "|    learning_rate        | 0.0003      |\n",
      "|    loss                 | 0.669       |\n",
      "|    n_updates            | 110         |\n",
      "|    policy_gradient_loss | -0.000837   |\n",
      "|    value_loss           | 1.42        |\n",
      "-----------------------------------------\n",
      "-----------------------------------------\n",
      "| rollout/                |             |\n",
      "|    ep_len_mean          | 1.19e+04    |\n",
      "|    ep_rew_mean          | 884         |\n",
      "| time/                   |             |\n",
      "|    fps                  | 61          |\n",
      "|    iterations           | 13          |\n",
      "|    time_elapsed         | 432         |\n",
      "|    total_timesteps      | 26624       |\n",
      "| train/                  |             |\n",
      "|    approx_kl            | 0.014833134 |\n",
      "|    clip_fraction        | 0.239       |\n",
      "|    clip_range           | 0.2         |\n",
      "|    entropy_loss         | -1.9        |\n",
      "|    explained_variance   | 0.452       |\n",
      "|    learning_rate        | 0.0003      |\n",
      "|    loss                 | 18.1        |\n",
      "|    n_updates            | 120         |\n",
      "|    policy_gradient_loss | -7.3e-05    |\n",
      "|    value_loss           | 26.3        |\n",
      "-----------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Training configuration, named so the values are easy to find and tune.\n",
    "SEED = 42                  # fixed seed so repeated runs start from the same random state\n",
    "TOTAL_TIMESTEPS = 25000    # requested training budget, in environment steps\n",
    "tensorboard_log = r'./tensorboard_log/'\n",
    "\n",
    "# PPO agent with the \"CnnPolicy\" network, trained on `env` from the cell above.\n",
    "# Passing `seed` makes the otherwise stochastic training run repeatable.\n",
    "model = PPO(\"CnnPolicy\", env, verbose=1,\n",
    "            tensorboard_log=tensorboard_log,\n",
    "            seed=SEED)\n",
    "\n",
    "# NOTE: the recorded log of this cell ends at 26624 steps, i.e. past the\n",
    "# 25000 requested; PPO appears to finish the rollout in progress before stopping.\n",
    "model.learn(total_timesteps=TOTAL_TIMESTEPS)\n",
    "\n",
    "# Persist the trained model under the name \"mario_model\".\n",
    "model.save(\"mario_model\")"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "9465aae7e0ab1403d672807d1a0963d86dbda2f584fbe3054c36cf78311c6c77"
  },
  "kernelspec": {
   "display_name": "Python 3.8.11 ('pytorch')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
